{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b518b04cbfe0"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "906e07f6e562"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "da6087fbd570"
      },
      "source": [
        "# Functional API"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d169f4a559d5"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>     <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/keras/functional\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\">TensorFlow.org で表示</a>   </td>\n",
        "  <td>     <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/guide/keras/functional.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">Google Colab で実行</a>   </td>\n",
        "  <td>     <a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/guide/keras/functional.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">GitHub でソースを表示</a>   </td>\n",
        "  <td>     <a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/guide/keras/functional.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\">ノートブックをダウンロード</a>   </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8d4ac441b1fc"
      },
      "source": [
        "## セットアップ"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ec52be14e686"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "from tensorflow.keras import layers"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "871fbb54ea07"
      },
      "source": [
        "## 前書き\n",
        "\n",
        "Keras *Functional API* は、`tf.keras.Sequential` API よりも柔軟なモデルの作成が可能で、非線形トポロジー、共有レイヤー、さらには複数の入力または出力を持つモデル処理することができます。\n",
        "\n",
        "これは、ディープラーニングのモデルは通常、レイヤーの有向非巡回グラフ（DAG）であるという考えに基づいてます。要するに、Functional API は*レイヤーのグラフ*を構築する方法です。\n",
        "\n",
        "次のモデルを考察してみましょう。\n",
        "\n",
        "```\n",
        "(input: 784-dimensional vectors)\n",
        "       ↧\n",
        "[Dense (64 units, relu activation)]\n",
        "       ↧\n",
        "[Dense (64 units, relu activation)]\n",
        "       ↧\n",
        "[Dense (10 units, softmax activation)]\n",
        "       ↧\n",
        "(output: logits of a probability distribution over 10 classes)\n",
        "```\n",
        "\n",
        "これは 3つ のレイヤーを持つ単純なグラフです。Functional API を使用してモデルを構築するために、まずは入力ノードを作成することから始めます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8d477c91955a"
      },
      "outputs": [],
      "source": [
        "inputs = keras.Input(shape=(784,))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "13c14d993620"
      },
      "source": [
        "データの形状は、784次元のベクトルとして設定されます。各サンプルの形状のみを指定するため、バッチサイズは常に省略されます。\n",
        "\n",
        "例えば、`(32, 32, 3)`という形状の画像入力がある場合には、次を使用します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e4732e8e279b"
      },
      "outputs": [],
      "source": [
        "# Just for demonstration purposes.\n",
        "img_inputs = keras.Input(shape=(32, 32, 3))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "971bf8b5588f"
      },
      "source": [
        "返される`inputs`には、モデルに供給する入力データの形状と`dtype`についての情報を含みます。形状は次のとおりです。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ee96c179846a"
      },
      "outputs": [],
      "source": [
        "inputs.shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "866eee86d63e"
      },
      "source": [
        "dtype は次のとおりです。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "480be92067f3"
      },
      "outputs": [],
      "source": [
        "inputs.dtype"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6c93172cdfba"
      },
      "source": [
        "この`inputs`オブジェクトのレイヤーを呼び出して、レイヤーのグラフに新しいノードを作成します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b50da8b1c28d"
      },
      "outputs": [],
      "source": [
        "dense = layers.Dense(64, activation=\"relu\")\n",
        "x = dense(inputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0f36afe42ff3"
      },
      "source": [
        "「レイヤー呼び出し」アクションは、「入力」から作成したこのレイヤーまで矢印を描くようなものです。`dense`レイヤーに入力を「渡して」、`x`を取得します。\n",
        "\n",
        "レイヤーのグラフにあと少しレイヤーを追加してみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "463d5cd0c484"
      },
      "outputs": [],
      "source": [
        "x = layers.Dense(64, activation=\"relu\")(x)\n",
        "outputs = layers.Dense(10)(x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e379f089b044"
      },
      "source": [
        "この時点で、レイヤーのグラフの入力と出力を指定することにより、`Model`を作成できます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7820cc2209a6"
      },
      "outputs": [],
      "source": [
        "model = keras.Model(inputs=inputs, outputs=outputs, name=\"mnist_model\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c9aa111852d3"
      },
      "source": [
        "モデルの概要がどのようなものか、確認しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4949ab8242e8"
      },
      "outputs": [],
      "source": [
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "99ab8535d6c3"
      },
      "source": [
        "また、モデルをグラフとしてプロットすることも可能です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6872f1b1b8b8"
      },
      "outputs": [],
      "source": [
        "keras.utils.plot_model(model, \"my_first_model.png\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6d9880136879"
      },
      "source": [
        "そしてオプションで、プロットされたグラフに各レイヤーの入力形状と出力形状を表示します 。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aa14046d3388"
      },
      "outputs": [],
      "source": [
        "keras.utils.plot_model(model, \"my_first_model_with_shape_info.png\", show_shapes=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "71969f9c91bb"
      },
      "source": [
        "この図とコードはほぼ同じです。コードバージョンでは、接続矢印は呼び出し演算に置き換えられています。\n",
        "\n",
        "「レイヤーのグラフ」はディープラーニングモデルの直感的なメンタルイメージであり、Functional API はこのメンタルイメージを忠実に映すモデルを作成する方法です。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "775b997c8c28"
      },
      "source": [
        "## トレーニング、評価、推論\n",
        "\n",
        "Functional API を使用して構築されたモデルの学習、評価、推論は、`Sequential`モデルの場合とまったく同じように動作します。\n",
        "\n",
        "`Model` クラスにはトレーニングループ（`fit()` メソッド）と評価ループ（`evaluate()` メソッド）が組み込まれています。[これらのループをカスタマイズ](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit/)することで、教師あり学習を超えるトレーニングのルーチン（[GAN](/examples/generative/dcgan_overriding_train_step/) など）を簡単に実装することができます。\n",
        "\n",
        "ここでは、MNIST 画像データを読み込み、ベクトルに再形成し、（検証分割のパフォーマンスを監視しながら）データ上でモデルを当てはめ、その後、テストデータ上でモデルを評価します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e61366d54487"
      },
      "outputs": [],
      "source": [
        "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
        "\n",
        "x_train = x_train.reshape(60000, 784).astype(\"float32\") / 255\n",
        "x_test = x_test.reshape(10000, 784).astype(\"float32\") / 255\n",
        "\n",
        "model.compile(\n",
        "    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    optimizer=keras.optimizers.RMSprop(),\n",
        "    metrics=[\"accuracy\"],\n",
        ")\n",
        "\n",
        "history = model.fit(x_train, y_train, batch_size=64, epochs=2, validation_split=0.2)\n",
        "\n",
        "test_scores = model.evaluate(x_test, y_test, verbose=2)\n",
        "print(\"Test loss:\", test_scores[0])\n",
        "print(\"Test accuracy:\", test_scores[1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2e13d7168c86"
      },
      "source": [
        "さらに詳しくは[トレーニングと評価](https://www.tensorflow.org/guide/keras/train_and_evaluate/)ガイドをご覧ください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "26991ef4dbbb"
      },
      "source": [
        "## 保存とシリアライズ\n",
        "\n",
        "Functional API を使用して構築されたモデルの保存とシリアル化は、`Sequential`のモデルと同じように動作します。Functional モデルを保存する標準的な方法は、`model.save()`を呼び出して、モデル全体を単一のファイルとして保存します。モデルを作成したコードが使用できなくなったとしても、後でこのファイルから同じモデルを再作成することが可能です。\n",
        "\n",
        "保存されたファイルには次を含みます。\n",
        "\n",
        "- モデルのアーキテクチャ\n",
        "- モデルの重み値（トレーニングの間に学習された値）\n",
        "- ある場合は、モデルトレーニング構成（`compile` に渡される構成)\n",
        "- あれば、オプティマイザとその状態（中断した所からトレーニングを再開するため）"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7e5e48669225"
      },
      "outputs": [],
      "source": [
        "model.save(\"path_to_my_model\")\n",
        "del model\n",
        "# Recreate the exact same model purely from the file:\n",
        "model = keras.models.load_model(\"path_to_my_model\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cfe2a761139b"
      },
      "source": [
        "詳細は、モデル[シリアライゼーションと保存](https://www.tensorflow.org/guide/keras/save_and_serialize/)ガイドをご覧ください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b747517364a9"
      },
      "source": [
        "## レイヤー群の同じグラフを使用してマルチモデルを定義する\n",
        "\n",
        "Functional API では、モデルはレイヤー群のグラフでそれらの出入力を指定することにより生成されます。それはレイヤー群の単一のグラフ が複数のモデルを生成するために使用できることを意味しています。\n",
        "\n",
        "次の例では、レイヤー群の同じスタックを使用して 2 つのモデルのインスタンス化を行います。これらは画像入力を 16 次元ベクトルに変換する`encoder`モデルと、トレーニングのためのエンドツーエンド`autoencoder`モデルです。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f9924d8c9ed3"
      },
      "outputs": [],
      "source": [
        "encoder_input = keras.Input(shape=(28, 28, 1), name=\"img\")\n",
        "x = layers.Conv2D(16, 3, activation=\"relu\")(encoder_input)\n",
        "x = layers.Conv2D(32, 3, activation=\"relu\")(x)\n",
        "x = layers.MaxPooling2D(3)(x)\n",
        "x = layers.Conv2D(32, 3, activation=\"relu\")(x)\n",
        "x = layers.Conv2D(16, 3, activation=\"relu\")(x)\n",
        "encoder_output = layers.GlobalMaxPooling2D()(x)\n",
        "\n",
        "encoder = keras.Model(encoder_input, encoder_output, name=\"encoder\")\n",
        "encoder.summary()\n",
        "\n",
        "x = layers.Reshape((4, 4, 1))(encoder_output)\n",
        "x = layers.Conv2DTranspose(16, 3, activation=\"relu\")(x)\n",
        "x = layers.Conv2DTranspose(32, 3, activation=\"relu\")(x)\n",
        "x = layers.UpSampling2D(3)(x)\n",
        "x = layers.Conv2DTranspose(16, 3, activation=\"relu\")(x)\n",
        "decoder_output = layers.Conv2DTranspose(1, 3, activation=\"relu\")(x)\n",
        "\n",
        "autoencoder = keras.Model(encoder_input, decoder_output, name=\"autoencoder\")\n",
        "autoencoder.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e87d4185b652"
      },
      "source": [
        "ここでは、デコーディングアーキテクチャはエンコーディングアーキテクチャに対して厳密に対称的であるため、出力形状は入力形状`(28, 28, 1)`と同じです。\n",
        "\n",
        "`Conv2D`レイヤーの反対は`Conv2DTranspose`レイヤーで、`MaxPooling2D`レイヤーの反対は`UpSampling2D`レイヤーです。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9c746c1a0b79"
      },
      "source": [
        "## レイヤー同様、全てのモデルは呼び出し可能\n",
        "\n",
        "任意のモデルを`Input`上、あるいはもう一つのレイヤーの出力上で呼び出すことによって、それがレイヤーであるかのように扱うことができます。モデルを呼び出すことにより、単にモデルのアーキテクチャを再利用しているのではなく、その重みも再利用していることになります。\n",
        "\n",
        "これを実際に見るために、エンコーダモデルとデコーダモデルを作成し、それらを 2 回の呼び出しに連鎖してオートエンコーダ―モデルを取得する、オートエンコーダの異なる例を次に示します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "862ac58e928b"
      },
      "outputs": [],
      "source": [
        "encoder_input = keras.Input(shape=(28, 28, 1), name=\"original_img\")\n",
        "x = layers.Conv2D(16, 3, activation=\"relu\")(encoder_input)\n",
        "x = layers.Conv2D(32, 3, activation=\"relu\")(x)\n",
        "x = layers.MaxPooling2D(3)(x)\n",
        "x = layers.Conv2D(32, 3, activation=\"relu\")(x)\n",
        "x = layers.Conv2D(16, 3, activation=\"relu\")(x)\n",
        "encoder_output = layers.GlobalMaxPooling2D()(x)\n",
        "\n",
        "encoder = keras.Model(encoder_input, encoder_output, name=\"encoder\")\n",
        "encoder.summary()\n",
        "\n",
        "decoder_input = keras.Input(shape=(16,), name=\"encoded_img\")\n",
        "x = layers.Reshape((4, 4, 1))(decoder_input)\n",
        "x = layers.Conv2DTranspose(16, 3, activation=\"relu\")(x)\n",
        "x = layers.Conv2DTranspose(32, 3, activation=\"relu\")(x)\n",
        "x = layers.UpSampling2D(3)(x)\n",
        "x = layers.Conv2DTranspose(16, 3, activation=\"relu\")(x)\n",
        "decoder_output = layers.Conv2DTranspose(1, 3, activation=\"relu\")(x)\n",
        "\n",
        "decoder = keras.Model(decoder_input, decoder_output, name=\"decoder\")\n",
        "decoder.summary()\n",
        "\n",
        "autoencoder_input = keras.Input(shape=(28, 28, 1), name=\"img\")\n",
        "encoded_img = encoder(autoencoder_input)\n",
        "decoded_img = decoder(encoded_img)\n",
        "autoencoder = keras.Model(autoencoder_input, decoded_img, name=\"autoencoder\")\n",
        "autoencoder.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0f77623d9cd5"
      },
      "source": [
        "ご覧のように、モデルはネストすることができ、(モデルはちょうどレイヤーのようなものであるため）サブモデルを含むことができます。モデル・ネスティングのための一般的なユースケースは *アンサンブル* です。モデルのセットを (それらの予測を平均する) 単一のモデルにアンサンブルする方法の例を次に示します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3bb36b630e5d"
      },
      "outputs": [],
      "source": [
        "def get_model():\n",
        "    inputs = keras.Input(shape=(128,))\n",
        "    outputs = layers.Dense(1)(inputs)\n",
        "    return keras.Model(inputs, outputs)\n",
        "\n",
        "\n",
        "model1 = get_model()\n",
        "model2 = get_model()\n",
        "model3 = get_model()\n",
        "\n",
        "inputs = keras.Input(shape=(128,))\n",
        "y1 = model1(inputs)\n",
        "y2 = model2(inputs)\n",
        "y3 = model3(inputs)\n",
        "outputs = layers.average([y1, y2, y3])\n",
        "ensemble_model = keras.Model(inputs=inputs, outputs=outputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "447a319b73a6"
      },
      "source": [
        "## 複雑なグラフトポロジーを操作する\n",
        "\n",
        "### マルチ入力と出力を持つモデル\n",
        "\n",
        "Functional API はマルチ入力と出力の操作を容易にします。これは `Sequential` API では処理できません。\n",
        "\n",
        "たとえば、顧客が発行したチケットを優先度別にランク付けし、正しい部門にルーティングするシステムを構築する場合、モデルには次の 3 つの入力があります。\n",
        "\n",
        "- チケットの件名 （テキスト入力）\n",
        "- チケットの本文（テキスト入力）\n",
        "- ユーザーが追加した任意のタグ（カテゴリ入力）\n",
        "\n",
        "このモデルには 2 つの出力があります。\n",
        "\n",
        "- 0 と 1 の間のプライオリティスコア（スカラーシグモイド出力）\n",
        "- チケットを処理すべき部門（部門集合に渡るソフトマックス出力）\n",
        "\n",
        "このモデルは Functional API を使用すると数行で構築が可能です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "49009e53da63"
      },
      "outputs": [],
      "source": [
        "num_tags = 12  # Number of unique issue tags\n",
        "num_words = 10000  # Size of vocabulary obtained when preprocessing text data\n",
        "num_departments = 4  # Number of departments for predictions\n",
        "\n",
        "title_input = keras.Input(\n",
        "    shape=(None,), name=\"title\"\n",
        ")  # Variable-length sequence of ints\n",
        "body_input = keras.Input(shape=(None,), name=\"body\")  # Variable-length sequence of ints\n",
        "tags_input = keras.Input(\n",
        "    shape=(num_tags,), name=\"tags\"\n",
        ")  # Binary vectors of size `num_tags`\n",
        "\n",
        "# Embed each word in the title into a 64-dimensional vector\n",
        "title_features = layers.Embedding(num_words, 64)(title_input)\n",
        "# Embed each word in the text into a 64-dimensional vector\n",
        "body_features = layers.Embedding(num_words, 64)(body_input)\n",
        "\n",
        "# Reduce sequence of embedded words in the title into a single 128-dimensional vector\n",
        "title_features = layers.LSTM(128)(title_features)\n",
        "# Reduce sequence of embedded words in the body into a single 32-dimensional vector\n",
        "body_features = layers.LSTM(32)(body_features)\n",
        "\n",
        "# Merge all available features into a single large vector via concatenation\n",
        "x = layers.concatenate([title_features, body_features, tags_input])\n",
        "\n",
        "# Stick a logistic regression for priority prediction on top of the features\n",
        "priority_pred = layers.Dense(1, name=\"priority\")(x)\n",
        "# Stick a department classifier on top of the features\n",
        "department_pred = layers.Dense(num_departments, name=\"department\")(x)\n",
        "\n",
        "# Instantiate an end-to-end model predicting both priority and department\n",
        "model = keras.Model(\n",
        "    inputs=[title_input, body_input, tags_input],\n",
        "    outputs=[priority_pred, department_pred],\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ee2735b3eff1"
      },
      "source": [
        "では、モデルをプロットします。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "52c4dc6fd93e"
      },
      "outputs": [],
      "source": [
        "keras.utils.plot_model(model, \"multi_input_and_output_model.png\", show_shapes=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "907c119d04a4"
      },
      "source": [
        "このモデルをコンパイルする時に、各出力に異なる損失を割り当てることができます。また、各損失に異なる重みを割り当てて、トレーニング損失全体へのそれらの寄与をモジュール化することも可能です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3e1acef07668"
      },
      "outputs": [],
      "source": [
        "model.compile(\n",
        "    optimizer=keras.optimizers.RMSprop(1e-3),\n",
        "    loss=[\n",
        "        keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "        keras.losses.CategoricalCrossentropy(from_logits=True),\n",
        "    ],\n",
        "    loss_weights=[1.0, 0.2],\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e780b99b133a"
      },
      "source": [
        "出力レイヤーの名前が異なるため、次のように損失を指定することも可能です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "84ce034debf2"
      },
      "outputs": [],
      "source": [
        "model.compile(\n",
        "    optimizer=keras.optimizers.RMSprop(1e-3),\n",
        "    loss={\n",
        "        \"priority\": keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "        \"department\": keras.losses.CategoricalCrossentropy(from_logits=True),\n",
        "    },\n",
        "    loss_weights=[1.0, 0.2],\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "845b20ca3c9d"
      },
      "source": [
        "入力とターゲットの NumPy 配列のリストを渡し、モデルをトレーニングします。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ae5ff9364b19"
      },
      "outputs": [],
      "source": [
        "# Dummy input data\n",
        "title_data = np.random.randint(num_words, size=(1280, 10))\n",
        "body_data = np.random.randint(num_words, size=(1280, 100))\n",
        "tags_data = np.random.randint(2, size=(1280, num_tags)).astype(\"float32\")\n",
        "\n",
        "# Dummy target data\n",
        "priority_targets = np.random.random(size=(1280, 1))\n",
        "dept_targets = np.random.randint(2, size=(1280, num_departments))\n",
        "\n",
        "model.fit(\n",
        "    {\"title\": title_data, \"body\": body_data, \"tags\": tags_data},\n",
        "    {\"priority\": priority_targets, \"department\": dept_targets},\n",
        "    epochs=2,\n",
        "    batch_size=32,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3c87f1fbe7aa"
      },
      "source": [
        "`Dataset`オブジェクトで fit を呼び出す時、それは`([title_data, body_data, tags_data], [priority_targets, dept_targets])`などのリストのタプル、または`({'title': title_data, 'body': body_data, 'tags': tags_data}、{'priority': priority_targets, 'department': dept_targets})`などのディクショナリのタプルを yield する必要があります。\n",
        "\n",
        "さらに詳しい説明については、[トレーニングと評価](https://www.tensorflow.org/guide/keras/train_and_evaluate/)ガイドをご覧ください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "64ada3f80484"
      },
      "source": [
        "### トイ ResNet モデル\n",
        "\n",
        "複数の入力と出力を持つモデルに加えて、Functional API では非線形接続トポロジー、つまりシーケンシャルに接続されていないレイヤーを持つモデルの操作を容易にします。これは`Sequential` API では扱うことができません。\n",
        "\n",
        "これの一般的なユースケースは、残差接続です。これを実証するために、CIFAR10 向けのトイ ResNet モデルを構築してみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bfa8b7503813"
      },
      "outputs": [],
      "source": [
        "inputs = keras.Input(shape=(32, 32, 3), name=\"img\")\n",
        "x = layers.Conv2D(32, 3, activation=\"relu\")(inputs)\n",
        "x = layers.Conv2D(64, 3, activation=\"relu\")(x)\n",
        "block_1_output = layers.MaxPooling2D(3)(x)\n",
        "\n",
        "x = layers.Conv2D(64, 3, activation=\"relu\", padding=\"same\")(block_1_output)\n",
        "x = layers.Conv2D(64, 3, activation=\"relu\", padding=\"same\")(x)\n",
        "block_2_output = layers.add([x, block_1_output])\n",
        "\n",
        "x = layers.Conv2D(64, 3, activation=\"relu\", padding=\"same\")(block_2_output)\n",
        "x = layers.Conv2D(64, 3, activation=\"relu\", padding=\"same\")(x)\n",
        "block_3_output = layers.add([x, block_2_output])\n",
        "\n",
        "x = layers.Conv2D(64, 3, activation=\"relu\")(block_3_output)\n",
        "x = layers.GlobalAveragePooling2D()(x)\n",
        "x = layers.Dense(256, activation=\"relu\")(x)\n",
        "x = layers.Dropout(0.5)(x)\n",
        "outputs = layers.Dense(10)(x)\n",
        "\n",
        "model = keras.Model(inputs, outputs, name=\"toy_resnet\")\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "05aefc66c54f"
      },
      "source": [
        "モデルをプロットします。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ef7ac19c83be"
      },
      "outputs": [],
      "source": [
        "keras.utils.plot_model(model, \"mini_resnet.png\", show_shapes=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4f0883eae520"
      },
      "source": [
        "モデルをトレーニングします。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4e1c7b530071"
      },
      "outputs": [],
      "source": [
        "(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()\n",
        "\n",
        "x_train = x_train.astype(\"float32\") / 255.0\n",
        "x_test = x_test.astype(\"float32\") / 255.0\n",
        "y_train = keras.utils.to_categorical(y_train, 10)\n",
        "y_test = keras.utils.to_categorical(y_test, 10)\n",
        "\n",
        "model.compile(\n",
        "    optimizer=keras.optimizers.RMSprop(1e-3),\n",
        "    loss=keras.losses.CategoricalCrossentropy(from_logits=True),\n",
        "    metrics=[\"acc\"],\n",
        ")\n",
        "# We restrict the data to the first 1000 samples so as to limit execution time\n",
        "# on Colab. Try to train on the entire dataset until convergence!\n",
        "model.fit(x_train[:1000], y_train[:1000], batch_size=64, epochs=1, validation_split=0.2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e7f35a9a1061"
      },
      "source": [
        "## レイヤーを共有する\n",
        "\n",
        "Functional API のもう 1 つの良い使い方は、*共有レイヤー*を使用するモデルです。共有レイヤーは、同じモデルで複数回再利用されるレイヤーインスタンスのことで、レイヤーグラフ内の複数のパスに対応するフィーチャを学習します。\n",
        "\n",
        "共有レイヤーは、似たような空間からの入力（例えば、似た語彙を特徴とする 2 つの異なるテキスト）をエンコードするためにしばしば使用されます。これにより、それら異なる入力間での情報の共有を可能になり、より少ないデータでそのようなモデルをトレーニングすることが可能になります。与えられた単語が入力のいずれかに見られる場合、それは共有レイヤーを通過する全ての入力の処理に有用です。\n",
        "\n",
        "Functional API でレイヤーを共有するには、同じレイヤーインスタンスを複数回呼び出します。例えば、2 つの異なるテキスト入力間で共有される `Embedding` レイヤを以下に示します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4b8e6a4f3e88"
      },
      "outputs": [],
      "source": [
        "# Embedding for 1000 unique words mapped to 128-dimensional vectors\n",
        "shared_embedding = layers.Embedding(1000, 128)\n",
        "\n",
        "# Variable-length sequence of integers\n",
        "text_input_a = keras.Input(shape=(None,), dtype=\"int32\")\n",
        "\n",
        "# Variable-length sequence of integers\n",
        "text_input_b = keras.Input(shape=(None,), dtype=\"int32\")\n",
        "\n",
        "# Reuse the same layer to encode both inputs\n",
        "encoded_input_a = shared_embedding(text_input_a)\n",
        "encoded_input_b = shared_embedding(text_input_b)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b4f193a74581"
      },
      "source": [
        "## レイヤーのグラフのノードを抽出して再利用する\n",
        "\n",
        "操作しているレイヤーのグラフは静的なデータ構造であるため、アクセスして検査をすることができます。そして、これが関数型モデルを画像としてプロットする方法でもあります。\n",
        "\n",
        "これはまた、中間レイヤー（グラフ内の「ノード」）のアクティブ化にアクセスが可能で、他の場所で再利用できることを意味します。これは特徴抽出などに非常に便利です。\n",
        "\n",
        "例を見てみましょう。これは ImageNet 上で事前トレーニングされた、重みを持つ VGG19 モデルです。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8bdaa209ccbe"
      },
      "outputs": [],
      "source": [
        "vgg19 = tf.keras.applications.VGG19()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "874ef4b4de49"
      },
      "source": [
        "そしてこれらはグラフデータ構造をクエリして得られる、モデルの中間的なアクティブ化です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "391817839937"
      },
      "outputs": [],
      "source": [
        "features_list = [layer.output for layer in vgg19.layers]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e91a9dc2f5b0"
      },
      "source": [
        "これらの機能を使用して、中間レイヤーのアクティブ化の値を返す新しい特徴抽出モデルを作成します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "36a450517b63"
      },
      "outputs": [],
      "source": [
        "feat_extraction_model = keras.Model(inputs=vgg19.input, outputs=features_list)\n",
        "\n",
        "img = np.random.random((1, 224, 224, 3)).astype(\"float32\")\n",
        "extracted_features = feat_extraction_model(img)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f2ac248fe202"
      },
      "source": [
        "これは特に、[ニューラルスタイル転送](https://keras.io/examples/generative/neural_style_transfer/)などのタスクに有用です。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c894ba891064"
      },
      "source": [
        "## カスタム層を使用して API を拡張する\n",
        "\n",
        "`tf.keras`には、例えば、次のような幅広い組み込みレイヤーが含まれています。\n",
        "\n",
        "- 畳み込みレイヤー : `Conv1D`、`Conv2D`、`Conv3D`、`Conv2DTranspose`\n",
        "- Pooling レイヤー : `MaxPooling1D`、`MaxPooling2D`、`MaxPooling3D`、`AveragePooling1D`\n",
        "- RNN レイヤー : `GRU`、`LSTM`、`ConvLSTM2D`\n",
        "- `BatchNormalization`、`Dropout`、`Embedding`、など\n",
        "\n",
        "必要なものが見つからない場合は、独自のレイヤーを作成して容易に API を拡張することができます。すべてのレイヤーは`Layer`クラスをサブクラス化して実装します。\n",
        "\n",
        "- `call`メソッドは、レイヤーが行う計算を指定します。\n",
        "- `build`メソッドは、レイヤーの重みを作成します（`__init__`でも重みを作成できるため、これは単なるスタイル慣習です）。\n",
        "\n",
        "レイヤーの新規作成に関する詳細については、[カスタムレイヤーとモデル](https://www.tensorflow.org/guide/keras/custom_layers_and_models)ガイドをご覧ください。\n",
        "\n",
        "`tf.keras.layers.Dense`の基本的な実装を以下に示します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1d9faf1f622a"
      },
      "outputs": [],
      "source": [
        "class CustomDense(layers.Layer):\n",
        "    def __init__(self, units=32):\n",
        "        super(CustomDense, self).__init__()\n",
        "        self.units = units\n",
        "\n",
        "    def build(self, input_shape):\n",
        "        self.w = self.add_weight(\n",
        "            shape=(input_shape[-1], self.units),\n",
        "            initializer=\"random_normal\",\n",
        "            trainable=True,\n",
        "        )\n",
        "        self.b = self.add_weight(\n",
        "            shape=(self.units,), initializer=\"random_normal\", trainable=True\n",
        "        )\n",
        "\n",
        "    def call(self, inputs):\n",
        "        return tf.matmul(inputs, self.w) + self.b\n",
        "\n",
        "\n",
        "inputs = keras.Input((4,))\n",
        "outputs = CustomDense(10)(inputs)\n",
        "\n",
        "model = keras.Model(inputs, outputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b8933568358c"
      },
      "source": [
        "カスタムレイヤーでシリアル化をサポートするには、レイヤーインスタンスのコンストラクタ引数を返す `get_config` メソッドを定義します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b22a134918a2"
      },
      "outputs": [],
      "source": [
        "class CustomDense(layers.Layer):\n",
        "    def __init__(self, units=32):\n",
        "        super(CustomDense, self).__init__()\n",
        "        self.units = units\n",
        "\n",
        "    def build(self, input_shape):\n",
        "        self.w = self.add_weight(\n",
        "            shape=(input_shape[-1], self.units),\n",
        "            initializer=\"random_normal\",\n",
        "            trainable=True,\n",
        "        )\n",
        "        self.b = self.add_weight(\n",
        "            shape=(self.units,), initializer=\"random_normal\", trainable=True\n",
        "        )\n",
        "\n",
        "    def call(self, inputs):\n",
        "        return tf.matmul(inputs, self.w) + self.b\n",
        "\n",
        "    def get_config(self):\n",
        "        return {\"units\": self.units}\n",
        "\n",
        "\n",
        "inputs = keras.Input((4,))\n",
        "outputs = CustomDense(10)(inputs)\n",
        "\n",
        "model = keras.Model(inputs, outputs)\n",
        "config = model.get_config()\n",
        "\n",
        "new_model = keras.Model.from_config(config, custom_objects={\"CustomDense\": CustomDense})"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "015abf7d0508"
      },
      "source": [
        "オプションで、config ディクショナリが与えられたレイヤーインスタンスを再作成する際に使用されたクラスメソッド `from_config(cls, config)` を実装します。デフォルトの `from_config` の実装は以下の通りです。\n",
        "\n",
        "```python\n",
        "def from_config(cls, config):\n",
        "  return cls(**config)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b4ead34e01dd"
      },
      "source": [
        "## いつ Functional API を使用するか\n",
        "\n",
        "新しいモデルを作成する場合または `Model`クラスを直接サブクラス化する場合に、Keras Functional API を使用する必要があるのでしょうか？ 一般的に Functional API はより高レベル、より容易かつ安全で、サブクラス化されたモデルがサポートしない多くの特徴を持っています。\n",
        "\n",
        "ただし、レイヤーの有向非巡回グラフ（DAG）として容易に表現できないモデルを構築する場合には、モデルのサブクラス化がより大きな柔軟性を与えます。例えば、Functional API では Tree-RNN を実装できず、`Model`を直接サブクラス化する必要があります。\n",
        "\n",
        "Functional API とモデルのサブクラス化の違いに関する詳細については、[TensorFlow 2.0 における Symbolic API と Imperative API とは？](https://blog.tensorflow.org/2019/01/what-are-symbolic-and-imperative-apis.html)をご覧ください。\n",
        "\n",
        "### Functional API の長所 :\n",
        "\n",
        "以下のプロパティは、（データ構造体でもある）Sequential モデルには真であり、（Python のバイトコードであり、データ構造体ではない）サブクラス化されたモデルには真ではありません。\n",
        "\n",
        "#### 低い冗長性\n",
        "\n",
        "`super(MyClass, self).__init__(...)`、`def call(self, ...):`などがありません。\n",
        "\n",
        "比較しよう :\n",
        "\n",
        "```python\n",
        "inputs = keras.Input(shape=(32,))\n",
        "x = layers.Dense(64, activation='relu')(inputs)\n",
        "outputs = layers.Dense(10)(x)\n",
        "mlp = keras.Model(inputs, outputs)\n",
        "```\n",
        "\n",
        "サブクラス化されたバージョンと比べます。\n",
        "\n",
        "```python\n",
        "class MLP(keras.Model):\n",
        "\n",
        "  def __init__(self, **kwargs):\n",
        "    super(MLP, self).__init__(**kwargs)\n",
        "    self.dense_1 = layers.Dense(64, activation='relu')\n",
        "    self.dense_2 = layers.Dense(10)\n",
        "\n",
        "  def call(self, inputs):\n",
        "    x = self.dense_1(inputs)\n",
        "    return self.dense_2(x)\n",
        "\n",
        "# Instantiate the model.\n",
        "mlp = MLP()\n",
        "# Necessary to create the model's state.\n",
        "# The model doesn't have a state until it's called at least once.\n",
        "_ = mlp(tf.zeros((1, 32)))\n",
        "```\n",
        "\n",
        "#### 連結グラフを定義しながらモデルを検証する\n",
        "\n",
        "Functional API では、入力仕様（形状とdtype）が（`Input`を使用して）あらかじめ作成されています。レイヤーを呼び出すたびに、レイヤーは渡された仕様が想定と一致しているかどうかをチェックし、一致していない場合には有用なエラーメッセージを表示します。\n",
        "\n",
        "これにより、Functional API を使用して構築できるモデルは全て確実に実行されます。収束関連のデバッグ以外の全てのデバッグは、実行時ではなく、モデル構築中に静的に行われます。これはコンパイラの型チェックに類似しています。\n",
        "\n",
        "#### 関数型モデルはプロット可能かつ検査可能です\n",
        "\n",
        "モデルをグラフとしてプロットすることが可能で、このグラフの中間ノードに簡単にアクセスすることができます。例えば（前の例で示したように）中間レイヤーのアクティブ化を抽出して再利用します。\n",
        "\n",
        "```python\n",
        "features_list = [layer.output for layer in vgg19.layers]\n",
        "feat_extraction_model = keras.Model(inputs=vgg19.input, outputs=features_list)\n",
        "```\n",
        "\n",
        "#### 関数型モデルは、シリアル化やクローン化が可能です\n",
        "\n",
        "関数型モデルはコードの一部ではなくデータ構造であるため、安全なシリアル化が可能です。また、単一のファイルとして保存できるため、元のコードにアクセスすることなく全く同じモデルを再作成することができます。詳細は[シリアル化と保存に関するガイド](https://www.tensorflow.org/guide/keras/save_and_serialize/)をご覧ください。\n",
        "\n",
        "サブクラス化されたモデルをシリアライズするには、実装者がモデルレベルで`get_config()`および`from_config()`メソッドを指定する必要があります。\n",
        "\n",
        "### Functional API の弱点 :\n",
        "\n",
        "#### 動的アーキテクチャをサポートしません\n",
        "\n",
        "Functional API は、モデルをレイヤーの DAG として扱います。これはほとんどのディープラーニングアーキテクチャでは真ですが、必ずしも全てのアーキテクチャに該当するわけではありません。例えば、再帰的ネットワークや Tree RNN はこの想定に従わないため、Functional API では実装できません。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "72992d4ed462"
      },
      "source": [
        "## 異なる API スタイルをうまく組み合わせる\n",
        "\n",
        "Functional API とモデルのサブクラス化のいずれかを選択することは、モデルの 1 つのカテゴリに制限する二者択一ではありません。`tf.keras` API 内の全てのモデルは、それらが`Sequential` モデルでも、関数型モデルでも、新規に書かれたサブクラス化されたモデルであっても、お互いに相互作用することができます。\n",
        "\n",
        "関数型モデルや`Sequential`モデルは、サブクラス化されたモデルやレイヤーの一部として常に使用することができます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3c6221508766"
      },
      "outputs": [],
      "source": [
        "units = 32\n",
        "timesteps = 10\n",
        "input_dim = 5\n",
        "\n",
        "# Define a Functional model\n",
        "inputs = keras.Input((None, units))\n",
        "x = layers.GlobalAveragePooling1D()(inputs)\n",
        "outputs = layers.Dense(1)(x)\n",
        "model = keras.Model(inputs, outputs)\n",
        "\n",
        "\n",
        "class CustomRNN(layers.Layer):\n",
        "    def __init__(self):\n",
        "        super(CustomRNN, self).__init__()\n",
        "        self.units = units\n",
        "        self.projection_1 = layers.Dense(units=units, activation=\"tanh\")\n",
        "        self.projection_2 = layers.Dense(units=units, activation=\"tanh\")\n",
        "        # Our previously-defined Functional model\n",
        "        self.classifier = model\n",
        "\n",
        "    def call(self, inputs):\n",
        "        outputs = []\n",
        "        state = tf.zeros(shape=(inputs.shape[0], self.units))\n",
        "        for t in range(inputs.shape[1]):\n",
        "            x = inputs[:, t, :]\n",
        "            h = self.projection_1(x)\n",
        "            y = h + self.projection_2(state)\n",
        "            state = y\n",
        "            outputs.append(y)\n",
        "        features = tf.stack(outputs, axis=1)\n",
        "        print(features.shape)\n",
        "        return self.classifier(features)\n",
        "\n",
        "\n",
        "rnn_model = CustomRNN()\n",
        "_ = rnn_model(tf.zeros((1, timesteps, input_dim)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "41f42eb2a9c0"
      },
      "source": [
        "次のいずれかのパターンに従った`call`メソッドを実装していれば、Functional API で任意のサブクラス化されたレイヤーやモデルを使用することができます。\n",
        "\n",
        "- `call(self, inputs, **kwargs)` -- ここでいう`inputs`は、テンソルまたはテンソルのネストされた構造（テンソルのリストなど）であり、`**kwargs`は非テンソルの引数（非 inputs）です。\n",
        "- `call(self, inputs, training=None, **kwargs)` -- この`training` は、レイヤーがトレーニングモードと推論モードで振る舞うべきかどうかを示すブールです。\n",
        "- `call(self, inputs, mask=None, **kwargs)` -- この`mask`は、ブールマスクテンソルです。（例えば RNNに便利です。）\n",
        "- `call(self, inputs, training=None, mask=None, **kwargs)` -- もちろん、マスキングとトレーニング固有の動作の両方を同時に持つことができます。\n",
        "\n",
        "さらに、カスタムレイヤーやモデルで`get_config`メソッドを実装する場合、作成した関数型モデルは依然としてシリアル化やクローン化が可能です。\n",
        "\n",
        "新規に書かれたカスタム RNN を関数型モデルで使用する簡単な例を以下に示します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3deb90222d05"
      },
      "outputs": [],
      "source": [
        "units = 32\n",
        "timesteps = 10\n",
        "input_dim = 5\n",
        "batch_size = 16\n",
        "\n",
        "\n",
        "class CustomRNN(layers.Layer):\n",
        "    def __init__(self):\n",
        "        super(CustomRNN, self).__init__()\n",
        "        self.units = units\n",
        "        self.projection_1 = layers.Dense(units=units, activation=\"tanh\")\n",
        "        self.projection_2 = layers.Dense(units=units, activation=\"tanh\")\n",
        "        self.classifier = layers.Dense(1)\n",
        "\n",
        "    def call(self, inputs):\n",
        "        outputs = []\n",
        "        state = tf.zeros(shape=(inputs.shape[0], self.units))\n",
        "        for t in range(inputs.shape[1]):\n",
        "            x = inputs[:, t, :]\n",
        "            h = self.projection_1(x)\n",
        "            y = h + self.projection_2(state)\n",
        "            state = y\n",
        "            outputs.append(y)\n",
        "        features = tf.stack(outputs, axis=1)\n",
        "        return self.classifier(features)\n",
        "\n",
        "\n",
        "# Note that you specify a static batch size for the inputs with the `batch_shape`\n",
        "# arg, because the inner computation of `CustomRNN` requires a static batch size\n",
        "# (when you create the `state` zeros tensor).\n",
        "inputs = keras.Input(batch_shape=(batch_size, timesteps, input_dim))\n",
        "x = layers.Conv1D(32, 3)(inputs)\n",
        "outputs = CustomRNN()(x)\n",
        "\n",
        "model = keras.Model(inputs, outputs)\n",
        "\n",
        "rnn_model = CustomRNN()\n",
        "_ = rnn_model(tf.zeros((1, 10, 5)))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "functional.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
